
import numpy as np
import SimpleITK as sitk
import os
import configparser

class SitkDataReader:
    """Load every image file in a directory with SimpleITK.

    The directory is scanned once at construction time (non-recursively).
    File names are sorted lexicographically, and every file is read eagerly
    with ``sitk.ReadImage``. ``file_objects[i]`` is the image loaded from
    ``files[i]``.

    Attributes:
        dir_name: The directory passed to the constructor.
        files: Sorted names (not full paths) of the files that were read.
        num_data: ``len(files)``.
        file_objects: Loaded SimpleITK images, in the same order as ``files``.
    """

    def __init__(self, dir_name):
        """Read all image files found directly inside *dir_name*.

        Args:
            dir_name: Path of the directory to scan. Subdirectories are not
                descended into.

        Subdirectories and dot-files (e.g. ``.DS_Store``) are skipped, since
        ``sitk.ReadImage`` cannot read them and would raise. Any other file
        that is not a readable image still makes ``sitk.ReadImage`` raise.
        """
        self.dir_name = dir_name
        # Lexicographic (not natural) order: "10.nii" sorts before "2.nii".
        self.files = sorted(
            name
            for name in os.listdir(dir_name)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(dir_name, name))
        )
        self.num_data = len(self.files)
        # Every image is loaded up front, so memory use grows with the
        # number and size of the files in the directory.
        self.file_objects = [
            sitk.ReadImage(os.path.join(dir_name, name)) for name in self.files
        ]

    def get_file_obj(self, case_indices=0):
        """Return the loaded SimpleITK image at position *case_indices*.

        Args:
            case_indices: Position in ``files`` / ``file_objects`` (an int,
                despite the plural name). Negative values index from the end.

        Raises:
            IndexError: If the position is out of range.
        """
        return self.file_objects[case_indices]

    def get_data(self, case_indices=0):
        """Return the voxel data of image *case_indices* as a NumPy array.

        ``sitk.GetArrayFromImage`` returns axes in reverse index order, e.g.
        ``(z, y, x)`` for a 3-D scalar image. Swapping the first and last
        axes turns that into ``(x, y, z)``.

        NOTE(review): only the first and last axes are swapped. This is a full
        reversal for 2-D/3-D scalar images, but not for 4-D or multi-component
        images (SimpleITK appends the component axis last). Confirm callers
        only use scalar 2-D/3-D data.

        Raises:
            IndexError: If the position is out of range.
        """
        array = sitk.GetArrayFromImage(self.file_objects[case_indices])
        return array.swapaxes(0, -1)